#!/usr/bin/env python
# _*_ coding=utf-8 _*_

from numpy import *
from Polymode import *
from Polymode import MPISolver

def choose_results_file():
	import re, os

	savedir = 'results'
	modename = re.compile("^(.*)\.queue$")
	files = filter(modename.match, os.listdir(savedir))
	for n,f in enumerate(files):
		print "%d) %s" %(n,modename.match(f).groups()[0])
	#fselect = int(raw_input("Display which files? "))
	fselect = array(map(int,raw_input("Display which files? ").split()))
	filenames = [os.path.join(savedir,files[ii]) for ii in fselect]
	return filenames

def calculate_first_intersect(x,y,intersect, direction=1):
	"""
	Find where y first passes the level `intersect` when scanning along x.

	x, y      : numpy arrays of equal length (they are indexed with an
	            index array, so plain lists will not work)
	intersect : level to be located
	direction : +1 scans x in increasing order and looks for the first
	            sample with y < intersect; -1 scans x in decreasing order
	            and looks for the first sample with y > intersect

	Returns (ix, ifirst) where ifirst is the index, in the scan order, of
	the first sample satisfying the condition and ix is the x value obtained
	by linear interpolation between samples ifirst-1 and ifirst.

	If no sample satisfies the condition the last two samples of the scan
	are used for the interpolation (i.e. an extrapolation) and ifirst is
	len(y)-1.  If the very first sample already satisfies the condition
	there is no preceding sample to interpolate with, so (x[0], 0) is
	returned.
	"""
	#Sort x and y
	isort = argsort(x)[::direction]
	x = x[isort]
	y = y[isort]

	# flatnonzero always gives a 1-d index array.  squeeze(nonzero(...))
	# collapsed a single match to a 0-d array, on which len() fails.
	ifind = flatnonzero(direction*y < direction*intersect)
	if len(ifind)>0:
		ifirst = ifind[0]
	else:
		ifirst = len(y)-1

	# Without this guard y[ifirst-1] would wrap around to the last sample
	if ifirst == 0:
		return x[0], ifirst

	#Linear interpolation
	ix = (y[ifirst-1]-intersect)/(y[ifirst-1]-y[ifirst])*(x[ifirst]-x[ifirst-1]) + x[ifirst-1]
	return ix, ifirst

def calculate_median_intersect(x,y,intersect, direction=1, nrun=5):
	"""
	Find where a running median of y first exceeds the level `intersect`.

	x, y      : numpy arrays of equal length (they are indexed with an
	            index array, so plain lists will not work)
	intersect : level to be located
	direction : +1 scans x in increasing order, -1 in decreasing order.
	            Note that it only selects the scan order; the test is
	            always "median > intersect".
	nrun      : number of consecutive samples in each median window

	Returns (ix, ifirst).  ifirst is the index (in scan order) of the
	centre, i.e. start + nrun//2, of the first window whose median exceeds
	the level; ix is the x value obtained by linear interpolation between
	samples ifirst-1 and ifirst.

	If no window qualifies (including when there are fewer than nrun
	samples) the last two samples of the scan are used and ifirst is
	len(y)-1.  If ifirst is 0 there is no preceding sample to interpolate
	with, so (x[0], 0) is returned.
	"""
	#Sort x and y
	isort = argsort(x)[::direction]
	x = x[isort]
	y = y[isort]

	# Windows start at 0 .. len(x)-nrun inclusive; the "+1" makes sure the
	# last full window is examined as well (it used to be skipped).
	nwindows = len(x)-nrun+1
	above = array([median(y[ii:ii+nrun])>intersect for ii in range(nwindows)], dtype=bool)
	ifind = flatnonzero(above) + nrun//2

	if len(ifind)>0:
		ifirst = ifind[0]
	else:
		ifirst = len(y)-1

	# Without this guard y[ifirst-1] would wrap around to the last sample
	if ifirst == 0:
		return x[0], ifirst

	#Linear interpolation
	ix = (y[ifirst-1]-intersect)/(y[ifirst-1]-y[ifirst])*(x[ifirst]-x[ifirst-1]) + x[ifirst-1]
	return ix, ifirst

def calculate_mode_cutoff(neffs, cutoff=100):
	loss_all = 2e7*imag(neffs)/log(10)
	if loss_all.max()<cutoff:
		print "Max loss less than cutoff"
		return real(neffs).min(),-1
	return calculate_median_intersect(real(neffs), loss_all, cutoff, direction=-1)
	#return calculate_first_intersect(real(neffs), loss_all, cutoff, direction=-1)	

def te_planar_evalue(nex, n0, ns, k, d):
	"""
	TE eigenvalue function of a planar guide, evaluated at the trial
	effective index nex.

	nex : trial effective index
	n0  : index entering the term sqrt(nex**2 - n0**2)
	ns  : index entering the term sqrt(ns**2 - nex**2)
	k   : wavenumber multiplying both square roots
	d   : thickness multiplying the ns-side wavenumber inside sin and cos

	find_planar_modes searches for a root of this function in nex and
	reports it as the TE effective index.
	"""
	gs = k*sqrt(ns**2 - nex**2)
	g0 = k*sqrt(nex**2 - n0**2)
	phase = d*gs
	mismatch = g0/gs - gs/g0
	return mismatch*sin(phase) + 2*cos(phase)

def tm_planar_evalue(nex, n0, ns, k, d):
	"""
	TM eigenvalue function of a planar guide, evaluated at the trial
	effective index nex.

	Same form as the TE function, except that the two wavenumber ratios
	are weighted by (ns/n0)**2 and (n0/ns)**2 respectively.

	nex : trial effective index
	n0  : index entering the term sqrt(nex**2 - n0**2)
	ns  : index entering the term sqrt(ns**2 - nex**2)
	k   : wavenumber multiplying both square roots
	d   : thickness multiplying the ns-side wavenumber inside sin and cos

	find_planar_modes searches for a root of this function in nex and
	reports it as the TM effective index.
	"""
	gs = k*sqrt(ns**2 - nex**2)
	g0 = k*sqrt(nex**2 - n0**2)
	phase = d*gs
	mismatch = (ns/n0)**2*g0/gs - (n0/ns)**2*gs/g0
	return mismatch*sin(phase) + 2*cos(phase)

def find_planar_modes(nclad, ncore, d, k, napprox=None):
	"""
	Solve for one TE and one TM effective index of a planar guide.

	nclad   : cladding index (passed as n0 to the eigenvalue functions)
	ncore   : core index (passed as ns to the eigenvalue functions)
	d       : guide thickness
	k       : wavenumber
	napprox : initial guess for the effective index; defaults to the
	          midpoint between nclad and ncore

	Returns (teneff, tmneff), the arrays returned by scipy's fsolve for
	te_planar_evalue and tm_planar_evalue.  fsolve converges to a root
	near the initial guess, so which mode is found depends on napprox.
	"""
	from scipy import optimize

	if napprox is None:
		# Starting exactly at nclad makes sqrt(nex**2 - n0**2) zero, so the
		# gs/g0 term of both eigenvalue functions is infinite at the first
		# evaluation.  The midpoint lies inside nclad < nex < ncore where
		# both square roots are real and non-zero.
		napprox = 0.5*(nclad + ncore)

	solver_args = (nclad, ncore, k, d)
	teneff = optimize.fsolve(te_planar_evalue, x0=napprox, args=solver_args)
	tmneff = optimize.fsolve(tm_planar_evalue, x0=napprox, args=solver_args)
	return teneff,tmneff


if __name__ == "__main__":
	#Parse options even as an add-on module!
	from optparse import OptionParser,Values
	import datetime
	import cPickle as pickle

	usage = "usage: %prog [options] [filename]"
	parser = OptionParser(usage)

	parser.add_option("-s", "--status",
		action="store_true", dest="status", default=False,
		help="queue status")

	parser.add_option("-j", "--job",
		action="store_true", dest="jobinfo", default=False,
		help="display information about a job")

	parser.add_option("-i", "--info",
		action="store_true", dest="modeinfo", default=False,
		help="display information on modes calculated")

	parser.add_option("-c", "--compress",
		action="store_true", dest="compress", default=False,
		help="remove vector information from modes")

	parser.add_option("-p", "--plot",
		action="store", dest="plot", default=None,
		help="plot the modes either as [t]ransmission, [n]eff or [l]oss")

	parser.add_option("-e", "--edit",
		action="store_true", dest="edit", default=False,
		help="edit queue")
	
	options, args = parser.parse_args()
	
	if len(args)==0:
		filenames = choose_results_file()
	elif len(args)==1:
		filenames = args
	else:
		parser.error("Expected filename got %d arguments" % len(args))

	resave_queue=False
	loadmpi=True

	if loadmpi:
		queue = []
		for filename in filenames:
			q = pickle.load(open(filename,'rb'))

			#Add filename to queue
			for item in q:
				item.solver.info['q'] = filename
			print "Loaded MPI queue of %d items from %s" % (len(q), filename)
			if hasattr(q,'comment'):
				print '"%s"\n' % q.comment
			queue += q

		#Find all different tags to plot
		unique_tags = []
		for item in queue:
			if item.solver.info not in unique_tags:
				unique_tags += [item.solver.info]

	if options.status or options.edit:
		for n,item in enumerate(queue):
			time_taken = "?"
			if item.started_at is not None:
				if item.finished_at is None:
					time_taken =  datetime.datetime.now()-item.started_at
				else:
					time_taken =  item.finished_at-item.started_at

			percentage = item.solver.percent_complete()
			status = MPISolver.Status.asstring(item.status)
			numberdone = len(item.solver.modes)
	
			if item.isfinished():
				print "[%d] : %s, Modes: %d, Time to solve: %s" % (n, status, numberdone, time_taken)
			elif not item.ispending():
				print "[%d] : %s, Modes: %d, Complete: %.2g%% Time taken to now: %s" \
					% (n, status, numberdone, percentage, time_taken)
		ncomplete = sum([i.iscompleted() for i in queue])
		nfailed = sum([i.isfailed() for i in queue])
		npending = sum([i.ispending() for i in queue])
		print "Found %d queue items, %d complete, %d failed, %d pending" \
			% (len(queue), ncomplete, nfailed, npending)

	if options.jobinfo:
		print
		itemnum = int(raw_input("Display which queue item? "))
		while itemnum>-1:
			item = queue[itemnum]
			print "%s" % item
			print item.info()

			itemnum = int(raw_input("Display which queue item? "))

	if options.compress:
		for item in queue:
			if hasattr(item.solver.equation, 'cinx'):
				del item.solver.equation.cinx
				
			if len(item.solver.modes)>0:
				for m in item.solver.modes:
					m.right = m.left = None
		resave_queue=True

	if options.edit:
		print
		itemnum = int(raw_input("Edit which queue item? "))
		while itemnum>-1:
			item = queue[itemnum]

			print "%s" % item
			print item.info()
			#print "Edit what: [s]tatus"
		
			if True:
				status_dict = MPISolver.Status.listall()
				print ", ".join(map(lambda s: "%s: %d" % (s[1],s[0]), status_dict.items()))
				sselect = int(raw_input("Enter new status "))
				item.status = sselect

			print "Modified item: %s" % (item)
			print
			itemnum = int(raw_input("Edit which queue item? "))
		
		resave_queue=True
		
	if options.plot:
		from pylab import *
		f1 = figure(); a1 = axes(); a2 = None
		f2 = None
		print 
		for series,tag in enumerate(unique_tags):
			#if options.which is not None and series not in options.which:
			#	continue
			qx = filter(lambda i: i.solver.info==tag, queue)

			modes = []
			for i in qx: modes += i.solver.get_data()
			neffs = [m.neff for m in modes]
		
			if 'n' in options.plot:
				a1.plot(real(neffs), imag(neffs), 'o')
				a1.set_xlabel(r'Re$(n_{neff})$')
				a1.set_ylabel(r'Im$(n_{neff})$')

			if 'd' in options.plot:
				bins = 25
				
				#Plot histogram on new axis
				dneff = (real(neffs).max()-real(neffs).min())/bins
				if a2 is None: a2 = twinx()

				nmodes, nav = histogram(real(neffs), bins=bins)
				a2.plot(nav,nmodes, linestyle='steps-post', color='black', alpha=0.2)
				a2.set_ylabel("Mode density")
				print "Histogram dneff=%.4g" % dneff

			if 'c' in options.plot:
				neffx, Ncut = calculate_mode_cutoff(neffs, cutoff=100)
				yb = a1.get_ybound()
				xb = a1.get_xbound()
				
				a1.plot([neffx, neffx], yb, 'r:')
				a1.text(neffx+(xb[1]-xb[0])*0.01, 0.9*yb[1]+yb[0]*0.1, "Cutoff=%.5g" % neffx)


			if 'g' in options.plot:
				if f2 is None:
					f2 = figure()
				
					gridnumber = 1
					rownum = floor(sqrt(len(unique_tags)))
					colnum = ceil(len(unique_tags)/rownum)
				
				f2.add_subplot(rownum,colnum,gridnumber)
				wg = qx[0].solver.wg
				
				nsec=5
				ext = wg.extents()
				wg.plot(sectors=nsec, fc='black')
				axis([ext[0]-5, ext[1]+5, ext[1]*tan(ext[2])*1.5, \
					ext[1]*tan((ext[3]-ext[2])*nsec)])

			if 'v' in options.plot:
				c = modes[0].coord.new(Nshape=(50, 3*modes[0].coord.symmetry), symmetry=1)

				ngs = []
				for m in modes:
					if hasattr(m, 'calculated_quantities'):
						ngs.append( physcon.c/m.calculated_quantities['Vg'] )
					elif m.right is not None:
						ngs.append( m.group_index(qx[0].solver.wg, coord=c) )
				
				a1.plot(real(neffs), real(ngs), '.')
				a1.set_xlabel(r'$n_{neff}$')
				a1.set_ylabel(r'$n_{g}$')

		show()

	if options.modeinfo:
		from scipy import stats
		for series,tag in enumerate(unique_tags):
			#if options.which is not None and series not in options.which:
			#	continue
			qx = filter(lambda i: i.solver.info==tag, queue)

			modes = []
			for i in qx:
				modes += i.solver.get_data()

			modes = Modes.filter_unique_modes(modes, by='neff')
			neffs = array([m.neff for m in modes])
			
			if len(modes)<1:
				continue
			
			wg = i.solver.wg
			k0 = i.solver.k0
			ncore = wg.index0(k0)
			nclad = 1
			
			m0s = unique([m.m0 for m in modes])
			for m0 in m0s:
				modes_m0 = array(filter(lambda m: m.m0==m0, modes))
				neffs_m0 = array([m.neff for m in modes_m0])
				Nm = len(modes_m0)
				
				#Number of modes to cutoff
				neffx, Ncut = calculate_mode_cutoff(neffs_m0, cutoff=100)
				NA = sqrt(i.solver.wg.index0(i.solver.k0)**2 - neffx**2)
				av_mode_density = Ncut/(real(neffs_m0).max()-real(neffx))

				print "\tm0=%d, N_M=%d, Ncutoff=%d, Mode density=%.2f" \
					% (m0,Nm,Ncut,av_mode_density)
				print "\t | Cutoff: %.5g, NA from cutoff: %.4g" % (neffx, NA)
			print "-"*40

			#Number of modes to cutoff
			neffx, Ncut = calculate_mode_cutoff(neffs, cutoff=100)
			NA = sqrt(ncore**2 - neffx**2)

			print "Cutoff: %.5g, NA from cutoff: %.4g" % (neffx, NA)
			print "Number of modes before cutoff: %d" % Ncut
			
			dplanar = None
			if 'd' in tag: dplanar = tag['d']
			if 'd2' in tag: dplanar = tag['d2']
			if dplanar is not None:
				bridge_neffs = find_planar_modes(nclad, ncore, dplanar, k0, 1.44)
				
				print "TE planar mode for %.2g, neff:%.5g" \
					% (dplanar,bridge_neffs[0])
				print "TM planar mode for %.2g, neff:%.5g" \
					% (dplanar,bridge_neffs[1])
			
			isort = real(neffs).argsort()
			modes_before_cutoff = array(modes)[isort[-Ncut:]]
			neffs_before_cutoff = array(neffs)[isort[-Ncut:]]

			#Calculate group velocities
			coord = modes[0].coord
			c = coord.new(Nshape=(50, 3*coord.symmetry), symmetry=1)
			vgs = []
			for m in modes_before_cutoff:
				if hasattr(m, 'calculated_quantities'):
					vgs.append( m.calculated_quantities['Vg'] )
				elif m.right is not None:
					vgs.append( m.group_velocity(wg, coord=c) )
			
			slope,y0 = stats.linregress(real(neffs_before_cutoff), real(vgs))[0:2]
			vgs_min = slope*neffx + y0
			vgs_max = slope*real(neffs_before_cutoff[-1]) + y0
			
			from pylab import *
			plot(neffs_before_cutoff, vgs, 'r.')
			plot(neffs_before_cutoff, array(neffs_before_cutoff)*slope + y0, 'b-')

			deltat = real(1/vgs_min - 1/vgs_max)
			print "δt (ns) = ", deltat*1e9
			print "Bandwidth (GHz) Le>L (L=100m) = ", 1e-9/(2*100*deltat)
			print "Bandwidth (GHz) Le=1 (L=100m) = ", 1e-9/(2*sqrt(100)*deltat)
			
			print "-"*40
			print

	if resave_queue:
		print "Saving modified queue as %s.new" % filename
		pickle.dump(queue, open(filename+".new",'wb'), protocol=-1)


